---
title: Алгоритм Винограда — Штрассена
description: Рассмотрим модификацию алгоритма Штрассена для умножения квадратных матриц с меньшим количеством сложений между блоками, чем в обычном алгоритме — 15 вместо...
sections: [Многопоточность,Блочные матрицы,Сравнение алгоритмов]
tags: [java,потоки,массивы,многомерные массивы,матрицы,рекурсия,циклы,вложенные циклы]
canonical_url: /ru/2022/02/10/winograd-strassen-algorithm.html
url_translated: /en/2022/02/11/winograd-strassen-algorithm.html
title_translated: Winograd — Strassen algorithm
date: 2022.02.10
---

Рассмотрим модификацию алгоритма Штрассена для умножения квадратных матриц с *меньшим* количеством
сложений между блоками, чем в обычном алгоритме — 15 вместо 18 и таким же количеством умножений как
в обычном алгоритме — 7. Будем использовать потоки Java Stream.

Рекурсивное дробление матриц на блоки при перемножении имеет смысл до определенного предела, а дальше
теряет смысл, так как алгоритм Штрассена не использует кеш среды выполнения. Поэтому для малых блоков
будем использовать параллельный вариант вложенных циклов, а для больших блоков параллельно будем
выполнять рекурсивное дробление.

Границу между двумя алгоритмами определяем экспериментально — подстраиваем под кеш среды выполнения.
Выгода алгоритма Штрассена становится заметнее на больших матрицах — отличие от алгоритма на вложенных
циклах становится больше и зависит от среды выполнения. Сравним время работы двух алгоритмов.

*Алгоритм на трёх вложенных циклах: [Оптимизация умножения матриц]({{ '/ru/2021/12/09/optimizing-matrix-multiplication.html' | relative_url }}).*

## Описание алгоритма {#algorithm-description}

Матрицы должны быть одинакового размера. Разделяем каждую матрицу на 4 равных блока. Блоки должны быть
квадратными, поэтому если это не так, тогда сначала дополняем матрицы нулевыми строками и столбцами,
а после этого разделяем на блоки. Лишние строки и столбцы потом уберём из результирующей матрицы.

{% include image_svg.html src="/img/block-matrices.svg" style="width:221pt; height:33pt;"
alt="{\displaystyle A={\begin{pmatrix}A_{11}&A_{12}\\A_{21}&A_{22}\end{pmatrix}},\quad B={\begin{pmatrix}B_{11}&B_{12}\\B_{21}&B_{22}\end{pmatrix}}.}" %}

Суммируем блоки.

{% include image_svg.html src="/img/sums1.svg" style="width:101pt; height:148pt;"
alt="{\displaystyle{\begin{aligned}S_{1}&=(A_{21}+A_{22});\\S_{2}&=(S_{1}-A_{11});\\S_{3}&=(A_{11}-A_{21});\\S_{4}&=(A_{12}-S_{2});\\S_{5}&=(B_{12}-B_{11});\\S_{6}&=(B_{22}-S_{5});\\S_{7}&=(B_{22}-B_{12});\\S_{8}&=(S_{6}-B_{21}).\end{aligned}}}" %}

Умножаем блоки.

{% include image_svg.html src="/img/products.svg" style="width:75pt; height:127pt;"
alt="{\displaystyle{\begin{aligned}P_{1}&=S_{2}S_{6};\\P_{2}&=A_{11}B_{11};\\P_{3}&=A_{12}B_{21};\\P_{4}&=S_{3}S_{7};\\P_{5}&=S_{1}S_{5};\\P_{6}&=S_{4}B_{22};\\P_{7}&=A_{22}S_{8}.\end{aligned}}}" %}

Суммируем блоки.

{% include image_svg.html src="/img/sums2.svg" style="width:78pt; height:31pt;"
alt="{\displaystyle{\begin{aligned}T_{1}&=P_{1}+P_{2};\\T_{2}&=T_{1}+P_{4}.\end{aligned}}}" %}

Блоки результирующей матрицы.

{% include image_svg.html src="/img/sums3.svg" style="width:240pt; height:33pt;"
alt="{\displaystyle{\begin{pmatrix}C_{11}&C_{12}\\C_{21}&C_{22}\end{pmatrix}}={\begin{pmatrix}P_{2}+P_{3}&T_{1}+P_{5}+P_{6}\\T_{2}-P_{7}&T_{2}+P_{5}\end{pmatrix}}.}" %}

## Гибридный алгоритм {#hybrid-algorithm}

Каждую матрицу `A` и `B` делим на 4 *равных* блока, при необходимости дополняем недостающие части
нулями. Выполняем 15 сложений и 7 умножений над блоками — получаем 4 блока матрицы `C`. Убираем
лишние нули, если добавляли, и возвращаем результирующую матрицу. Рекурсивное дробление больших
блоков запускаем в параллельном режиме, а для малых блоков вызываем алгоритм на вложенных циклах.

```java
/**
 * @param n   размер матрицы
 * @param brd минимальный размер матрицы
 * @param a   первая матрица 'n×n'
 * @param b   вторая матрица 'n×n'
 * @return результирующая матрица 'n×n'
 */
public static int[][] multiplyMatrices(int n, int brd, int[][] a, int[][] b) {
    // малые блоки перемножаем с помощью алгоритма на вложенных циклах
    if (n < brd) return simpleMultiplication(n, a, b);
    // серединная точка матрицы, округляем в большую сторону — блоки должны
    // быть квадратными, при необходимости добавляем нулевые строки и столбцы
    int m = n - n / 2;
    // блоки первой матрицы
    int[][] a11 = getQuadrant(m, n, a, true, true);
    int[][] a12 = getQuadrant(m, n, a, true, false);
    int[][] a21 = getQuadrant(m, n, a, false, true);
    int[][] a22 = getQuadrant(m, n, a, false, false);
    // блоки второй матрицы
    int[][] b11 = getQuadrant(m, n, b, true, true);
    int[][] b12 = getQuadrant(m, n, b, true, false);
    int[][] b21 = getQuadrant(m, n, b, false, true);
    int[][] b22 = getQuadrant(m, n, b, false, false);
    // суммируем блоки
    int[][] s1 = sumMatrices(m, a21, a22, true);
    int[][] s2 = sumMatrices(m, s1, a11, false);
    int[][] s3 = sumMatrices(m, a11, a21, false);
    int[][] s4 = sumMatrices(m, a12, s2, false);
    int[][] s5 = sumMatrices(m, b12, b11, false);
    int[][] s6 = sumMatrices(m, b22, s5, false);
    int[][] s7 = sumMatrices(m, b22, b12, false);
    int[][] s8 = sumMatrices(m, s6, b21, false);
    int[][][] p = new int[7][][];
    // перемножаем блоки в параллельных потоках
    IntStream.range(0, 7).parallel().forEach(i -> {
        switch (i) { // рекурсивные вызовы
            case 0: p[i] = multiplyMatrices(m, brd, s2, s6); break;
            case 1: p[i] = multiplyMatrices(m, brd, a11, b11); break;
            case 2: p[i] = multiplyMatrices(m, brd, a12, b21); break;
            case 3: p[i] = multiplyMatrices(m, brd, s3, s7); break;
            case 4: p[i] = multiplyMatrices(m, brd, s1, s5); break;
            case 5: p[i] = multiplyMatrices(m, brd, s4, b22); break;
            case 6: p[i] = multiplyMatrices(m, brd, a22, s8); break;
        }
    });
    // суммируем блоки
    int[][] t1 = sumMatrices(m, p[0], p[1], true);
    int[][] t2 = sumMatrices(m, t1, p[3], true);
    // блоки результирующей матрицы
    int[][] c11 = sumMatrices(m, p[1], p[2], true);
    int[][] c12 = sumMatrices(m, t1, sumMatrices(m, p[4], p[5], true), true);
    int[][] c21 = sumMatrices(m, t2, p[6], false);
    int[][] c22 = sumMatrices(m, t2, p[4], true);
    // собираем результирующую матрицу из блоков,
    // убираем нулевые строки и столбцы, если добавляли
    return putQuadrants(m, n, c11, c12, c21, c22);
}
```

{% capture collapsed_md %}
```java
// вспомогательный метод для суммирования матриц
private static int[][] sumMatrices(int n, int[][] a, int[][] b, boolean sign) {
    int[][] c = new int[n][n];
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            c[i][j] = sign ? a[i][j] + b[i][j] : a[i][j] - b[i][j];
    return c;
}
```
```java
// вспомогательный метод, получает блок матрицы
private static int[][] getQuadrant(int m, int n, int[][] x,
                                   boolean first, boolean second) {
    int[][] q = new int[m][m];
    if (first) for (int i = 0; i < m; i++)
        if (second) System.arraycopy(x[i], 0, q[i], 0, m); // x11
        else System.arraycopy(x[i], m, q[i], 0, n - m); // x12
    else for (int i = m; i < n; i++)
        if (second) System.arraycopy(x[i], 0, q[i - m], 0, m); // x21
        else System.arraycopy(x[i], m, q[i - m], 0, n - m); // x22
    return q;
}
```
```java
// вспомогательный метод, собирает матрицу из блоков
private static int[][] putQuadrants(int m, int n,
                                    int[][] x11, int[][] x12,
                                    int[][] x21, int[][] x22) {
    int[][] x = new int[n][n];
    for (int i = 0; i < n; i++)
        if (i < m) {
            System.arraycopy(x11[i], 0, x[i], 0, m);
            System.arraycopy(x12[i], 0, x[i], m, n - m);
        } else {
            System.arraycopy(x21[i - m], 0, x[i], 0, m);
            System.arraycopy(x22[i - m], 0, x[i], m, n - m);
        }
    return x;
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Вспомогательные методы" content=collapsed_md -%}

## Вложенные циклы {#nested-loops}

Для дополнения предыдущего алгоритма и для сравнения с ним же будем использовать *оптимизированный*
вариант вложенных циклов, который лучше прочих использует кеш среды выполнения — обработка строк
результирующей матрицы выполняется независимо друг от друга в параллельных потоках. Для малых матриц
будем использовать этот алгоритм — большие матрицы делим на малые блоки и используем этот же алгоритм.

```java
/**
 * @param n размер матрицы
 * @param a первая матрица 'n×n'
 * @param b вторая матрица 'n×n'
 * @return результирующая матрица 'n×n'
 */
public static int[][] simpleMultiplication(int n, int[][] a, int[][] b) {
    // результирующая матрица
    int[][] c = new int[n][n];
    // обходим индексы строк матрицы 'a' в параллельном режиме
    IntStream.range(0, n).parallel().forEach(i -> {
        // обходим индексы общей стороны двух матриц:
        // колонок матрицы 'a' и строк матрицы 'b'
        for (int k = 0; k < n; k++)
            // обходим индексы колонок матрицы 'b'
            for (int j = 0; j < n; j++)
                // сумма произведений элементов i-ой строки
                // матрицы 'a' и j-ой колонки матрицы 'b'
                c[i][j] += a[i][k] * b[k][j];
    });
    return c;
}
```

## Тестирование {#testing}

Для проверки возьмём две квадратные матрицы `A=[1000×1000]` и `B=[1000×1000]`, заполненные случайными
числами. Минимальный размер блока возьмём `[200×200]` элементов. Сначала сравниваем между собой
корректность реализации двух алгоритмов — произведения матриц должны совпадать. Затем выполняем каждый
метод по 10 раз и подсчитываем среднее время выполнения.

```java
// запускаем программу и выводим результат
public static void main(String[] args) {
    // входящие данные
    int n = 1000, brd = 200, steps = 10;
    int[][] a = randomMatrix(n, n), b = randomMatrix(n, n);
    // произведения матриц
    int[][] c1 = multiplyMatrices(n, brd, a, b);
    int[][] c2 = simpleMultiplication(n, a, b);
    // проверяем корректность результатов
    System.out.println("Результаты совпадают: " + Arrays.deepEquals(c1, c2));
    // замеряем время работы двух методов
    benchmark("Гибридный алгоритм", steps, () -> {
        int[][] c = multiplyMatrices(n, brd, a, b);
        if (!Arrays.deepEquals(c, c1)) System.out.print("ошибка");
    });
    benchmark("Вложенные циклы   ", steps, () -> {
        int[][] c = simpleMultiplication(n, a, b);
        if (!Arrays.deepEquals(c, c2)) System.out.print("ошибка");
    });
}
```

{% capture collapsed_md %}
```java
// вспомогательный метод, возвращает матрицу указанного размера
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            matrix[i][j] = (int) (Math.random() * row * col);
    return matrix;
}
```
```java
// вспомогательный метод для замера времени работы переданного кода
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // время выполнения одного шага
        System.out.print(" | " + time);
        avg += time;
    }
    // среднее время выполнения
    System.out.println(" || " + (avg / steps));
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Вспомогательные методы" content=collapsed_md -%}

Вывод зависит от среды выполнения, время в миллисекундах:

```
Результаты совпадают: true
Гибридный алгоритм | 196 | 177 | 156 | 205 | 154 | 165 | 133 | 118 | 132 | 134 || 157
Вложенные циклы    | 165 | 164 | 168 | 167 | 168 | 168 | 170 | 179 | 173 | 168 || 169
```

## Сравнение алгоритмов {#comparing-algorithms}

На восьмиядерном компьютере Linux x64 запускаем вышеописанный тест 100 раз вместо 10. Минимальный 
размер блока берём `[brd=200]`. Изменяем только `n` — размеры обеих матриц `A=[n×n]` и `B=[n×n]`.
Получаем сводную таблицу результатов. Время в миллисекундах.

```
                 n | 900 | 1000 | 1100 | 1200 | 1300 | 1400 | 1500 | 1600 | 1700 |
-------------------|-----|------|------|------|------|------|------|------|------|
Гибридный алгоритм |  96 |  125 |  169 |  204 |  260 |  313 |  384 |  482 |  581 |
Вложенные циклы    | 119 |  162 |  235 |  281 |  361 |  497 |  651 |  793 |  971 |
```

Результаты: выгода алгоритма Штрассена становится заметнее на больших матрицах, когда размер
самой матрицы в несколько раз превышает размер минимального блока, и зависит от среды выполнения.

Все описанные выше методы, включая свёрнутые блоки, можно поместить в одном классе.

{% capture collapsed_md %}
```java
import java.util.Arrays;
import java.util.stream.IntStream;
```
{% endcapture %}
{%- include collapsed_block.html summary="Необходимые импорты" content=collapsed_md -%}
